package com.yexx.limit.intercept;

import com.yexx.limit.RedisLimit;
import com.yexx.limit.annotation.RequestLimit;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;


/**
 * Spring MVC interceptor that applies Redis-backed request rate limiting.
 * Implemented with the help of SpringControllerLimit.
 *
 * <p>Only handler methods annotated with {@link RequestLimit} are checked. When the limiter
 * refuses a request, a JSON body of the form
 * {@code {"code": <errorCode>, "message": "<errorMsg>"}} is written to the response and the
 * handler is not invoked. The HTTP status code is not changed by this interceptor.
 *
 * @author zuomin(MylesZelic @ outlook.com)
 * @date 2020/04/11 18:33
 */
@Slf4j
public class SpringMvcIntercept extends HandlerInterceptorAdapter {

    /**
     * Template for the rejection body. The code is inserted verbatim; the message must be
     * JSON-escaped before being inserted (see {@link #escapeJson(String)}).
     */
    private static final String ACCESS_DENIED = "{\"code\": %s, \"message\": \"%s\"}";

    private final RedisLimit redisLimit;

    /**
     * @param redisLimit limiter consulted for every {@link RequestLimit}-annotated handler;
     *                   a {@code null} value is accepted here but makes every request fail in
     *                   {@link #preHandle}
     */
    public SpringMvcIntercept(RedisLimit redisLimit) {
        this.redisLimit = redisLimit;
    }

    /**
     * Checks the rate limit for handler methods annotated with {@link RequestLimit}.
     *
     * @return {@code true} to let the request proceed (no annotation, not a handler method, or
     *         the limiter allowed it); {@code false} after a JSON rejection body has been written
     * @throws NullPointerException if this interceptor was constructed with a {@code null} limiter
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {

        if (redisLimit == null) {
            throw new NullPointerException("redisLimit is null");
        }
        if (!(handler instanceof HandlerMethod)) {
            // Static resources and other non-controller handlers are never limited.
            return true;
        }
        HandlerMethod method = (HandlerMethod) handler;
        RequestLimit annotation = method.getMethodAnnotation(RequestLimit.class);
        if (annotation == null) {
            // Not opted in to rate limiting.
            return true;
        }
        if (redisLimit.limit(assemblyPrefix(method))) {
            return true;
        }

        log.warn("path : {} : {}", request.getRequestURI(), annotation.errorMsg());
        response.setContentType("application/json;charset=utf-8");
        response.setCharacterEncoding("utf-8");
        // The message is escaped so quotes, backslashes or control characters in it cannot
        // produce a malformed JSON body.
        String body = String.format(ACCESS_DENIED, annotation.errorCode(),
                escapeJson(String.valueOf(annotation.errorMsg())));
        try (PrintWriter writer = response.getWriter()) {
            writer.write(body);
            writer.flush();
        }
        return false;
    }

    /**
     * Builds the rate-limit key {@code <ClassName>:<methodName>} for a handler method.
     *
     * <p>{@code ClassName} is the fully-qualified name with the package stripped, so a nested
     * class keeps its {@code Outer$Inner} form. {@link HandlerMethod#getBeanType()} is used instead
     * of {@code getBean().getClass()} so that a CGLIB-proxied controller is keyed by its
     * user-declared class rather than by the generated proxy class.
     *
     * @param method resolved handler method
     * @return the key passed to the limiter
     */
    private String assemblyPrefix(HandlerMethod method) {
        String className = method.getBeanType().getName();
        String shortName = className.substring(className.lastIndexOf('.') + 1);
        return shortName + ":" + method.getMethod().getName();
    }

    /**
     * Escapes text for embedding inside a JSON string literal (RFC 8259, section 7).
     *
     * @param text raw text, not {@code null}
     * @return the escaped text, without surrounding quotes
     */
    private static String escapeJson(String text) {
        StringBuilder sb = new StringBuilder(text.length() + 16);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"':
                    sb.append("\\\"");
                    break;
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                case '\b':
                    sb.append("\\b");
                    break;
                case '\f':
                    sb.append("\\f");
                    break;
                default:
                    if (c < 0x20) {
                        // Remaining control characters must use the \\uXXXX form.
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.toString();
    }
}
